# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import json
import re
from enum import Enum
from glob import glob
from typing import Any, Dict, List, Tuple

import numpy as np
import pandas as pd
import trimesh.transformations as tra
import yaml
from slap_manipulation.agents.slap_agent import SLAPAgent

from home_robot.agent.ovmm_agent.ovmm_agent import build_vocab_from_category_map
from home_robot.agent.ovmm_agent.pick_and_place_agent import PickAndPlaceAgent
from home_robot.core.interfaces import (
    Action,
    ContinuousEndEffectorAction,
    ContinuousNavigationAction,
    DiscreteNavigationAction,
    GeneralTaskState,
    Observations,
)
from home_robot.utils.geometry import (
    sophus2xyt,
    xyt2sophus,
    xyt_base_to_global,
    xyt_global_to_base,
)
from home_robot.utils.point_cloud import show_point_cloud
from home_robot_hw.ros.utils import matrix_to_pose_msg


# methods for accessing code-list generated by LLM
def evaluate_expression(expression, dummy_value: List[str]) -> List[str]:
    """Evaluate an expression string returned by the LLM into a Python value.

    Args:
        expression: either an already-parsed list (returned unchanged) or a
            string containing a Python expression, e.g. ``"['bottle']"``.
        dummy_value: fallback returned when the expression cannot be evaluated.

    Returns:
        The evaluated value, or ``dummy_value`` if evaluation fails.
    """
    if isinstance(expression, list):
        return expression
    try:
        # NOTE(security): eval() on LLM-generated text executes arbitrary
        # code; this is acceptable only because the input comes from a
        # trusted evaluation pipeline. Consider ast.literal_eval upstream.
        result = eval(expression)
    except (SyntaxError, NameError) as e:
        # NameError covers bare identifiers the LLM emits without quotes,
        # e.g. "bottle" instead of "'bottle'".
        print(f"{type(e).__name__}: {expression}")
        result = dummy_value
    return result


def separate_into_codelist(string: str) -> List[str]:
    """Separate a flat JSON/LLM string into a code-list for GeneralLanguageAgent.

    The input is a run of function calls such as
    ``"goto(['x'])\npick_up(['x'])"``; it is split on the closing
    parenthesis and each fragment is cleaned of stray whitespace and
    punctuation before the ``')'`` is restored.

    Args:
        string: raw call-list text (e.g. an LLM prediction).

    Returns:
        One cleaned string per function call, e.g. ``["goto(['x'])", ...]``.
    """
    new_function_calls = []
    for call in string.split(")"):
        call = call.strip(" \n\t!;")
        call = call.replace("\n\t", "").replace("!", "")
        # Skip fragments that are empty after cleaning — notably the residue
        # after the final ')' — which previously produced a spurious ")" entry.
        if not call:
            continue
        new_function_calls.append(call + ")")  # restore the stripped ')'
    return new_function_calls


def get_taskplan_for_robot(steps: List[str]) -> List[dict]:
    """Parse call strings into structured steps for the robot.

    Each step of the form ``verb(args[, key='value'])`` becomes a dict with
    keys ``verb``, ``noun`` (evaluated argument list) and ``adverb``
    (upper-cased keyword value, or None). Non-matching strings are skipped.
    """
    # Groups: 1 = verb, 2 = positional args, 4/5 = optional key='value' pair.
    call_pattern = r"(\w+)\((.*?)?(,\s*([\w']+)=\'([^\)]+)\')?\)"
    plan = []
    for entry in steps:
        parsed = re.match(call_pattern, entry)
        if parsed is None:
            continue
        verb, raw_args, _, _, raw_adverb = parsed.groups()
        noun = evaluate_expression(raw_args, ["dummy"])
        adverb = raw_adverb.upper() if raw_adverb is not None else None
        plan.append({"verb": verb, "noun": noun, "adverb": adverb})
    return plan


def get_task_plans_from_llm(
    index: int,
    json_path: str = "./data/llm_eval/real_exp_v2.json",
    input_string: str = "prediction",
) -> List[str]:
    """Load the LLM-evaluation dataset and return the task plan at ``index``.

    Args:
        index: row index into the dataset's ``data`` records.
        json_path: path to the JSON dataset produced by the LLM evaluation.
        input_string: column holding the raw prediction text.

    Returns:
        A code-list of ``self.<skill>(...)`` strings for GeneralLanguageAgent.
    """
    with open(json_path, "r") as f:
        payload = json.load(f)
    records = pd.json_normalize(payload["data"])
    raw_steps = records.iloc[index][input_string]
    parsed_steps = get_taskplan_for_robot(separate_into_codelist(raw_steps))
    return get_codelist(parsed_steps)


def get_task_plans_from_oracle(
    index: int,
    datafile: str = "./datasets/BringXFromYSurfaceToHuman.json",
    root: str = "./datasets/",
):
    """Read the ground-truth dataset files and return the task plan at ``index``.

    Args:
        index: row index of the task plan to fetch.
        datafile: path to a single GT JSON file, or the literal string
            ``"all"`` to concatenate every ``*.json`` file under ``root``.
        root: directory searched when ``datafile == "all"``.

    Returns:
        A code-list of ``self.<skill>(...)`` strings for GeneralLanguageAgent.

    Raises:
        IndexError: if ``index`` is out of range for the loaded dataset.
    """
    if datafile == "all":
        frames = [pd.read_json(file) for file in glob(root + "*.json")]
        df = pd.concat(frames)
    else:
        df = pd.read_json(datafile)
    # Explicit raise instead of `assert`: asserts are stripped under -O, and
    # IndexError is what df.iloc would raise anyway.
    if index >= len(df):
        raise IndexError(f"Index {index} is out of range")
    steps_list = df.iloc[index]["steps"]
    return get_codelist(steps_list)


def get_codelist(steps_list):
    """Format parsed steps into ``self.<verb>(<noun>, obs=obs)`` call strings.

    This is the string API consumed (via eval) by GeneralLanguageAgent.act.
    """
    return [f"self.{step['verb']}({step['noun']}, obs=obs)" for step in steps_list]


class GeneralLanguageAgent(PickAndPlaceAgent):
    """Derived agent from PickAndPlaceAgent to execute free-form code-list
    generated by LLM. Uses SLAPAgent to execute learnt manipulation skills and
    OVMM stack to navigate.

    The agent runs a small state machine (GeneralTaskState) per step of the
    plan: NOT_STARTED -> PREPPING (mode switch) -> DOING_TASK -> IDLE, and
    dispatches each step string via eval() in :meth:`act`.
    """

    def __init__(
        self,
        cfg,
        debug: bool = True,
        task_id: int = -1,
        start_from: int = 0,
        **kwargs,
    ):
        """Initialize the agent.

        Args:
            cfg: experiment config; the AGENT, CORLAGENT, SLAP and EVAL
                sections are read below.
            debug: when True, get_steps() may use the built-in task plans.
            task_id: index selecting the hard-coded / per-skill task plan and
                the SLAP model variant.
            start_from: number of leading plan steps to skip (for resuming).
        """
        super().__init__(
            cfg,
            min_distance_goal_cm=cfg.AGENT.PLANNER.min_distance_goal_cm,
            continuous_angle_tolerance=cfg.AGENT.PLANNER.continuous_angle_tolerance,
            **kwargs,
        )
        # Visualizations
        self.task_id = task_id
        print(f"[GeneralLanguageAgent]: {self.task_id=}")
        self.steps = []
        self.start_from = start_from
        self.state = GeneralTaskState.NOT_STARTED
        self.mode = "navigation"  # TODO: turn into an enum
        self.current_step = ""
        self.cfg = cfg
        # for testing
        self.testing = False
        # llm_run: execute a task plan generated by the LLM dataset.
        # per_skill: execute a single SLAP skill from cfg.EVAL. Mutually
        # exclusive (checked below).
        self.llm_run = cfg.CORLAGENT.llm_run
        self.per_skill = cfg.CORLAGENT.per_skill
        self.verbose = cfg.CORLAGENT.verbose
        if self.llm_run and self.per_skill:
            raise RuntimeError(
                "llm_run and per_skill can't be true at the same time. llm_run means executing a taskplan generated by LLM. per_skill means executing a single skill using SLAP."
            )
        self.debug = debug
        self.dry_run = self.cfg.AGENT.dry_run
        self.slap_model = SLAPAgent(cfg, task_id=self.task_id)
        if not self.cfg.SLAP.dry_run:
            self.slap_model.load_models()
        self.num_actions_done = 0
        # Language templates and per-task metadata (e.g. number of actions
        # each learnt skill predicts) loaded from YAML config files.
        self._language = yaml.load(
            open(self.cfg.AGENT.language_file, "r"), Loader=yaml.FullLoader
        )
        self._task_information = yaml.load(
            open(self.cfg.AGENT.task_information_file, "r"),
            Loader=yaml.FullLoader,
        )
        # Maps cfg.EVAL.task_name entries to the agent method implementing it.
        self.skill_to_function = {
            "take-bottle": "take_bottle",
            "handover-to-person": "handover",
            "open-object-drawer": "open_object",
            "close-object-drawer": "close_object",
            "pour-into-bowl": "pour_into_bowl",
        }
        if self.llm_run:
            # Task plans are fetched lazily per task id from the LLM dataset.
            self.task_plans = get_task_plans_from_llm
        elif not self.llm_run and not self.per_skill:
            # Hard-coded ground-truth plans keyed by task_id; each entry is a
            # code-list evaluated step by step in act().
            self.task_plans = {
                0: [
                    "self.goto(['bottle'], obs)",
                    "self.take_bottle(['bottle'], obs)",
                    "self.goto(['person'], obs)",
                    "self.handover(['person'], obs)",
                ],
                1: [
                    "self.pick_up(['bottle'], obs)",
                    "self.goto(['counter'], obs)",
                    "self.place(['counter'], obs)",
                    "self.goto(['drawer', 'drawer handle'], obs)",
                    "self.open_object(['drawer handle'], obs)",
                    "self.goto(['bottle'], obs)",
                    "self.take_bottle(['bottle'], obs)",
                    "self.goto(['drawer handle'], obs)",
                    "self.place(['drawer handle'], obs)",
                    "self.goto(['drawer handle'], obs)",
                    "self.close_object(['drawer handle'], obs)",
                ],
                2: [
                    "self.goto(['drawer', 'drawer handle'], obs)",
                    "self.open_object(['drawer handle',], obs)",
                    "self.goto(['drawer handle', 'lemon'], obs)",
                    "self.pick_up(['lemon'], obs)",
                    "self.goto(['table'], obs)",
                    "self.place(['table'], obs)",
                ],
                3: [
                    "self.goto(['drawer', 'drawer handle'], obs)",
                    "self.open_object(['drawer handle',], obs)",
                    "self.goto(['drawer handle', 'headphones'], obs)",
                    "self.pick_up(['headphones'], obs)",
                    "self.goto(['person'], obs)",
                    "self.handover(['person'], obs)",
                ],
                4: [
                    "self.goto(['cup'], obs)",
                    "self.pick_up(['cup'], obs)",
                    "self.goto(['bowl'], obs)",
                    "self.pour_into_bowl(['bowl'], obs)",
                    "self.goto(['basket'], obs)",
                    "self.place(['basket'], obs)",
                ],
            }
        elif self.per_skill:
            # read the EVAL.task-name list from cfg and
            # map to a function call executing that skill
            task_name = cfg.EVAL.task_name[task_id]
            object_list = cfg.EVAL.object_list[task_id]
            if not isinstance(object_list, list):
                object_list = [object_list]
            # based on above generate a function_string of form:
            # self.<function_which_maps_to_skill>([object_list], obs)
            function_string = (
                f"self.{self.skill_to_function[task_name]}({object_list}, obs)"
            )
            self.task_plans = {
                task_id: [function_string],
            }

    # ---override methods---
    def reset(self):
        """Clear internal task state and reset component agents."""
        self.state = GeneralTaskState.NOT_STARTED
        self.object_nav_agent.reset()
        if self.gaze_agent is not None:
            self.gaze_agent.reset()

    def soft_reset(self):
        """Soft reset between steps of the overall plan.

        Unlike :meth:`reset`, keeps the remaining plan and returns to IDLE so
        the next step can start.
        """
        if self.verbose:
            print("[GeneralLanguageAgent] ObjectNav reset")
        self.state = GeneralTaskState.IDLE
        self.num_actions_done = 0
        self.slap_model.reset()
        self.object_nav_agent.reset()

    def _preprocess_obs(
        self, obs: Observations, object_list: List[str]
    ) -> Observations:
        """Given observations and :object_list: for current task populate the
        :obs.task_observations: dictionary so that the OVMM agent can find the
        right objects.

        We do not differentiate between objects and receptacles — everything
        is a semantic goal to be found. With two entries, object_list[0] is
        treated as the start receptacle (goal id 1) and object_list[1] as the
        object (goal id 2); with one entry it is the sole object goal (id 1).
        """
        if len(object_list) > 1:
            obs.task_observations["start_recep_goal"] = 1
            obs.task_observations["object_goal"] = 2
            obs.task_observations["object_name"] = object_list[1]
            obs.task_observations["start_recep_name"] = object_list[0]
            obs.task_observations["place_recep_goal"] = None
            obs.task_observations["place_recep_name"] = None
        else:
            obs.task_observations["place_recep_goal"] = None
            obs.task_observations["place_recep_name"] = None
            obs.task_observations["start_recep_goal"] = None
            obs.task_observations["start_recep_name"] = None
            obs.task_observations["object_goal"] = 1
            obs.task_observations["object_name"] = object_list[0]
        return obs

    def _preprocess_obs_for_place(
        self, obs: Observations, object_list: List[str]
    ) -> Observations:
        """Populate task_observations for the place skill.

        The end receptacle (object_list[0]) is the only semantic goal
        (id 1); object and start-receptacle goals are cleared.
        """
        obs.task_observations["end_recep_goal"] = 1
        obs.task_observations["end_recep_name"] = None
        obs.task_observations["object_goal"] = None
        obs.task_observations["goal_name"] = object_list[0]
        obs.task_observations["start_recep_goal"] = None
        obs.task_observations["start_recep_name"] = None
        return obs

    # --unique methods--
    def skill_is_done(self) -> bool:
        """is the current skill done and agent free?"""
        return self.state == GeneralTaskState.IDLE

    def task_is_done(self) -> bool:
        """is the entire task done?"""
        return len(self.steps) == 0 and self.state == GeneralTaskState.IDLE

    def is_busy(self) -> bool:
        """is the agent currently busy?"""
        return (
            self.state == GeneralTaskState.PREPPING
            or self.state == GeneralTaskState.DOING_TASK
        )

    def get_steps(self, task: str):
        """Takes in a task string and returns a list of steps to complete the task.

        ``task`` is interpreted as an integer task id: either an index into
        the LLM dataset (llm_run) or a key of the hard-coded plans.
        Honors :attr:`start_from` by dropping that many leading steps.
        """
        if self.testing or self.debug or self.llm_run:
            if self.llm_run:
                self.steps = self.task_plans(int(task))
            else:
                self.steps = self.task_plans[int(task)]
        else:
            raise NotImplementedError(
                "Getting plans outside of test tasks is not implemented yet"
            )
        if self.verbose:
            print("Task steps: ", self.steps)
        for i in range(self.start_from):
            self.steps.pop(0)

    def add_vocab_per_step(self):
        """based on the current step, add the vocab for that step to the semantic sensor"""
        self.object_lists = []
        for i, step in enumerate(self.steps):
            # Extracting the list of strings, e.g. "['drawer', 'handle']"
            matches = re.search(r"\[([^]]+)\]", step)

            if matches:
                # NOTE(review): eval on the bracketed substring assumes the
                # plan strings are trusted (they come from our own templates
                # or dataset).
                object_list = list(eval(matches.group(0)))
                print(object_list)
                self.object_lists.append(object_list)
                object_dict = {id: object for id, object in enumerate(object_list)}
                # Simple vocabulary contains only object and necessary receptacles
                simple_vocab = build_vocab_from_category_map(object_dict, {})
                self.semantic_sensor.update_vocabulary_list(simple_vocab, i)

    def goto(
        self, object_list: List[str], obs: Observations
    ) -> Tuple[Action, Dict[str, Any]]:
        """Skill to goto an object.

        First call switches to navigation mode (PREPPING); subsequent calls
        run the object-nav agent until it emits STOP, then returns to IDLE.
        """
        if self.verbose:
            print("[LangAgent]: In locate skill")
        info = {}
        if self.skip_find_object:
            # transition to the next state
            action = DiscreteNavigationAction.STOP
            self.state = GeneralTaskState.IDLE
        else:
            if not self.is_busy():
                if self.verbose:
                    print("[LangAgent]: Changing mode, setting goals")
                self.mode = "navigation"
                if self.verbose:
                    print(f"[LangAgent]: {self.mode=}")
                self.state = GeneralTaskState.PREPPING
                info["not_viz"] = True
                info["object_list"] = object_list
                return DiscreteNavigationAction.NAVIGATION_MODE, info
            else:
                self.state = GeneralTaskState.DOING_TASK
                obs = self._preprocess_obs(obs, object_list)
                # Tighten or relax the nav goal tolerance depending on
                # whether the NEXT step needs the robot close to the target.
                if "place" in self.steps[0] or "pick_up" in self.steps[0]:
                    if self.verbose:
                        print(
                            f"[GeneralLanguageAgent] Preparing for next task: {self.steps[0]} by decreasing rad + min_dist"
                        )
                    self.object_nav_agent.planner.min_goal_distance_cm = 50
                    self.object_nav_agent.planner.goal_dilation_selem_radius = 10
                else:
                    if self.verbose:
                        print(
                            f"[GeneralLanguageAgent] Preparing for next task: {self.steps[0]} by increasing rad + min_dist"
                        )
                    self.object_nav_agent.planner.min_goal_distance_cm = 80
                    self.object_nav_agent.planner.goal_dilation_selem_radius = 20
                action, info["viz"] = self.object_nav_agent.act(obs)
                if action == DiscreteNavigationAction.STOP or self.dry_run:
                    self.soft_reset()
                    self.state = GeneralTaskState.IDLE
        return action, info

    def pick_up(
        self, object_list: List[str], obs: Observations
    ) -> Tuple[Action, Dict[str, Any]]:
        """Heuristic skill to pick up an object.

        First call switches to manipulation mode; second call issues
        PICK_OBJECT and returns to IDLE.
        """
        info = {}
        if self.verbose:
            print("[LangAgent]: In pick_up skill")
        # return following if agent currently not in manip mode
        if (
            self.state == GeneralTaskState.IDLE
            or self.state == GeneralTaskState.NOT_STARTED
        ):
            if self.verbose:
                print(
                    "[LangAgent]: Change the mode of the robot to manipulation mode; set goals"
                )
            self.mode = "manipulation"
            info["not_viz"] = True
            info["object_list"] = object_list
            self.state = GeneralTaskState.PREPPING
            return DiscreteNavigationAction.MANIPULATION_MODE, info
        else:
            if self.verbose:
                print("[LangAgent]: Picking up with heuristic", object_list, obs)
            self.state = GeneralTaskState.IDLE
            # NOTE(review): this branch returns None instead of an info dict,
            # unlike every other skill — callers must tolerate a None info.
            return DiscreteNavigationAction.PICK_OBJECT, None

    def place_on(
        self, object_list: List[str], obs: Observations
    ) -> Tuple[Action, Dict[str, Any]]:
        """alias for self.place"""
        return self.place(object_list, obs)

    def place(
        self, object_list: List[str], obs: Observations
    ) -> Tuple[Action, Dict[str, Any]]:
        """Heuristic skill to place an object.

        First call switches to manipulation mode; subsequent calls run the
        place policy until it emits STOP, then returns to IDLE.
        """
        info = {}
        if self.verbose:
            print("[LangAgent]: In place skill")
        if not self.is_busy():
            self.mode = "manipulation"
            info["not_viz"] = True
            info["object_list"] = object_list
            self.state = GeneralTaskState.PREPPING
            return DiscreteNavigationAction.MANIPULATION_MODE, info
        else:
            if self.verbose:
                print(
                    "[LangAgent]: DRYRUN: Run SLAP on: place-on",
                    object_list,
                    obs,
                )
            self.state = GeneralTaskState.DOING_TASK
            # place the object somewhere - hopefully in front of the agent.
            # obs = self._preprocess_obs_for_place(obs, object_list)
            info["semantic_frame"] = obs.task_observations["semantic_frame"]
            info["object_list"] = object_list
            action, action_info = self.place_policy.forward(obs, info)
            if action == DiscreteNavigationAction.STOP:
                self.state = GeneralTaskState.IDLE
            return action, action_info

    def open(self, object_list, obs) -> Tuple[Action, Dict[str, Any]]:
        """alias for self.open_object"""
        return self.open_object(object_list, obs)

    def open_object(
        self, object_list: List[str], obs: Observations
    ) -> Tuple[Action, Dict[str, Any]]:
        """learnt skill to open an object"""
        language = "open-object-top-drawer"
        num_actions = self._task_information[language]
        return self.call_slap(language, num_actions, obs, object_list)

    def close_object(
        self, object_list: List[str], obs: Observations
    ) -> Tuple[Action, Dict[str, Any]]:
        """learnt skill to close an object"""
        language = "close-object-drawer"
        num_actions = self._task_information[language]
        return self.call_slap(language, num_actions, obs, object_list)

    def handover(
        self, object_list: List[str], obs: Observations
    ) -> Tuple[Action, Dict[str, Any]]:
        """learnt skill to handover an object to a person"""
        language = "handover-to-person"
        num_actions = self._task_information[language]
        return self.call_slap(language, num_actions, obs, object_list)

    def take(
        self, object_list: List[str], obs: Observations
    ) -> Tuple[Action, Dict[str, Any]]:
        """alias: learnt skill to take-bottle"""
        language = "take-bottle"
        num_actions = self._task_information[language]
        return self.call_slap(language, num_actions, obs, object_list)

    def take_bottle(
        self, object_list: List[str], obs: Observations
    ) -> Tuple[Action, Dict[str, Any]]:
        """learnt skill to take-bottle"""
        language = "take-bottle"
        num_actions = self._task_information[language]
        return self.call_slap(language, num_actions, obs, object_list)

    def pour_in(self, object_list, obs) -> Tuple[Action, Dict[str, Any]]:
        """alias for self.pour_into_bowl"""
        return self.pour_into_bowl(object_list, obs)

    def pour_into_bowl(
        self, object_list: List[str], obs: Observations
    ) -> Tuple[Action, Dict[str, Any]]:
        """learnt skill to pour-into-bowl"""
        language = "pour-into-bowl"
        num_actions = self._task_information[language]
        return self.call_slap(language, num_actions, obs, object_list)

    def call_slap(
        self, language: str, num_actions: int, obs, object_list: List[str]
    ) -> Tuple[Action, Dict[str, Any]]:
        """Main SLAP function which takes Observations, description of
        task to be executed, predicts actions using SLAPAgent and sends it back
        as well-formed action to Env.

        Two-phase protocol: the first call predicts an interaction point and
        returns a ContinuousNavigationAction to approach it (PREPPING);
        the next call predicts the end-effector trajectory and returns a
        ContinuousEndEffectorAction, then soft-resets to IDLE.
        """
        info = {}
        action = None
        obs.task_observations["task-name"] = language
        obs.task_observations["num-actions"] = num_actions
        obs.task_observations["object_list"] = object_list
        if not self.is_busy() or self.state == GeneralTaskState.PREPPING:
            if self.per_skill or self.state == GeneralTaskState.PREPPING:
                self.state = GeneralTaskState.DOING_TASK
                info["object_list"] = object_list
                print(f"[AGENT] {object_list=}")
                return DiscreteNavigationAction.MANIPULATION_MODE, info
            if self.verbose:
                print("[LangAgent]: Changing mode, setting goals")
            self.state = GeneralTaskState.PREPPING
            # rotate the obs before sending it in
            camera_pose = obs.task_observations["base_camera_pose"]
            xyz = tra.transform_points(obs.xyz.reshape(-1, 3), camera_pose)

            # PCD comes from nav mode, where robot base rot is diff
            # from that we trained on, so add another rotation
            rot_matrix = tra.euler_matrix(0, 0, -np.pi / 2)
            obs.xyz = tra.transform_points(xyz, rot_matrix)

            result, info = self.slap_model.predict(obs)

            # rotate back the predicted point
            rot_matrix = tra.euler_matrix(0, 0, np.pi / 2)
            info["interaction_point"] = tra.transform_points(
                info["interaction_point"].reshape(-1, 3), rot_matrix
            ).reshape(-1)

            # structure for tasks such that the robot can face objects head-on;
            # per-task approach offsets (direction, yaw, stand-off distance)
            if "open-object" in language:
                info["global_offset_vector"] = np.array([0, 1, 0])
                info["global_orientation"] = np.deg2rad(-90)
                info["offset_distance"] = 0.83
            if "close-object" in language:
                info["global_offset_vector"] = np.array([0, 1, 0])
                info["global_orientation"] = np.deg2rad(-90)
                info["offset_distance"] = 0.5
            if "handover" in language:
                info["global_offset_vector"] = np.array([-1, 0, 0])
                info["global_orientation"] = np.deg2rad(0)
                info["offset_distance"] = 0.95
            if "take-bottle" == language:
                info["global_offset_vector"] = np.array([0, 1, 0])
                info["global_orientation"] = np.deg2rad(-90)
                info["offset_distance"] = 0.65
            if "pour-into-bowl" == language:
                info["global_offset_vector"] = np.array([0, 1, 0])
                info["global_orientation"] = np.deg2rad(-90)
                info["offset_distance"] = 0.65
            # Navigate to the interaction point projected onto the ground plane.
            projected_point = np.copy(info["interaction_point"])
            projected_point[2] = 0
            info["SLAP"] = True
            action = ContinuousNavigationAction(projected_point)
            self.slap_model.reset()
            return action, info
        else:
            # rotate the obs before sending it in
            camera_pose = obs.task_observations["base_camera_pose"]
            obs.xyz = tra.transform_points(obs.xyz.reshape(-1, 3), camera_pose)
            result, info = self.slap_model.predict(obs)
            if result is not None:
                # result columns: [:3] position, [3:7] orientation (quat),
                # [7] gripper state
                action = ContinuousEndEffectorAction(
                    result[:, :3],
                    result[:, 3:7],
                    np.expand_dims(result[:, 7], -1),
                )
            else:
                # Prediction failed: emit a random placeholder action so the
                # episode can proceed rather than crash.
                action = ContinuousEndEffectorAction(
                    np.random.rand(1, 3),
                    np.random.rand(1, 4),
                    np.random.rand(1, 1),
                )
            self.soft_reset()
            self.state = GeneralTaskState.IDLE
            return action, info

    def act(self, obs: Observations, task: str) -> Tuple[Action, Dict[str, Any]]:
        """High-level act method which generates plan and then calls the appropriate
        skill based on planned list in :self.steps:.

        On first call fetches the plan and per-step vocabularies; whenever the
        agent is free it pops the next step and switches the semantic-sensor
        vocabulary to match.
        """
        if self.state == GeneralTaskState.NOT_STARTED and len(self.steps) == 0:
            self.get_steps(task)
            self.add_vocab_per_step()
            self.current_step_id = -1
        if not self.is_busy():
            if self.verbose:
                print(f"[LangAgent]: {self.state=}")
            self.current_step = self.steps.pop(0)
            self.current_step_id += 1
            self.current_object_list = self.object_lists[self.current_step_id]
            if self.verbose:
                print(
                    f"[LangAgent] {self.current_step=}, {self.current_step_id=}, {self.current_object_list=}"
                )
            if self.semantic_sensor.current_vocabulary_id != self.current_step_id:
                self.semantic_sensor.set_vocabulary(self.current_step_id)
        if self.verbose:
            print(
                f"[LangAgent]: evaling: {self.current_step=}, {self.current_step_id=}"
            )
        # NOTE(review): self.config is presumably set by the parent agent
        # (self.cfg is the local copy) — confirm against PickAndPlaceAgent.
        if self.config.GROUND_TRUTH_SEMANTICS == 0:
            if (
                self.semantic_sensor is not None
                and self.semantic_sensor.current_vocabulary_id is not None
            ):
                obs = self._preprocess_obs(obs, self.current_object_list)
                obs = self.semantic_sensor(obs)
        else:
            obs.task_observations["semantic_frame"] = None
        # NOTE(security): eval() dispatches the step string (e.g.
        # "self.goto(['cup'], obs)"); plan strings must come from a trusted
        # source.
        action, info = eval(self.current_step)
        return action, info
